{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training the Ultra-Fast Lane Detection network\n",
    "\n",
    "This Jupyter Notebook trains the Ultra-Fast Lane Detection network using LLAMAS data.\n",
    "\n",
    "You need the LLAMAS dataset (adjust `BASE_PATH` below) and the LLAMAS Python library. Clone the library into the same directory as this notebook:\n",
    "```\n",
    "git clone https://github.com/karstenBehrendt/unsupervised_llamas.git\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import cv2\n",
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from models.ultrafast import UltraFastNet\n",
    "\n",
    "# Network geometry. INPUT_SIZE is (height, width); CLS_SHAPE is handed to the\n",
    "# model as `cls_dim` and is also used as the shape of the label tensors.\n",
    "NUM_LANES = 2\n",
    "INPUT_SIZE = (288, 800)\n",
    "CLS_SHAPE = (20, 100, NUM_LANES)\n",
    "\n",
    "# Tensor shapes used by the data pipeline: 3-channel input, model output.\n",
    "INPUT_SHAPE = INPUT_SIZE + (3,)\n",
    "OUTPUT_SHAPE = CLS_SHAPE\n",
    "\n",
    "# dtype of the tensors produced by the data pipeline\n",
    "DTYPE = tf.float32\n",
    "\n",
    "ufm = UltraFastNet(\n",
    "    num_lanes=NUM_LANES,\n",
    "    size=INPUT_SIZE,\n",
    "    cls_dim=CLS_SHAPE,\n",
    "    use_aux=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ufm.compile(optimizer='adam', loss=losses.ultrafast_loss)\n",
    "ufm.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "TAKE_LLAMAS_MAX = 1000\n",
    "TAKE_CULANE_MAX = 0\n",
    "TAKE_MIWULA_MAX = 0 # all of them\n",
    "\n",
    "# Ratio of validation data vs training data\n",
    "VALID_RATIO = 0.3\n",
    "\n",
    "BATCH_SIZE = 8 # tune to available GPU memory\n",
    "PREFETCH_SIZE = 200 # tune to available memory\n",
    "\n",
    "#USE_DATA_AUGMENTATION = True\n",
    "\n",
    "EPOCHS = 20 # training epochs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def display(display_list):\n",
    "    \"\"\"Show the given images side by side in a single row.\n",
    "\n",
    "    The panels are titled, in order, 'Input Image', 'True Mask' and\n",
    "    'Predicted Mask', so at most three entries are supported.\n",
    "    \"\"\"\n",
    "    titles = ('Input Image', 'True Mask', 'Predicted Mask')\n",
    "    n_panels = len(display_list)\n",
    "\n",
    "    plt.figure(figsize=(15, 15))\n",
    "    for pos, image in enumerate(display_list, start=1):\n",
    "        plt.subplot(1, n_panels, pos)\n",
    "        plt.title(titles[pos - 1])\n",
    "        plt.imshow(tf.keras.preprocessing.image.array_to_img(image))\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**TODO:** This LLAMAS loading code was copied from an earlier image-segmentation experiment and still needs to be adapted to the Ultra-Fast network. It works well enough for now."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "from unsupervised_llamas.label_scripts.spline_creator import get_horizontal_values_for_four_lanes\n",
    "\n",
    "BASE_PATH = '/Volumes/tank_data/Archiv/ML/llamas/labels/'\n",
    "\n",
    "LLAMAS_SHAPE = (717, 1276, 3)\n",
    "#MASK_SHAPE = (712, 1272, 3)\n",
    "\n",
    "IMAGE_SHAPE = INPUT_SHAPE\n",
    "MASK_SHAPE = OUTPUT_SHAPE\n",
    "\n",
    "PARALLEL_JOBS = 4# tf.data.experimental.AUTOTUNE\n",
    "\n",
    "\n",
    "def resize_img(img, shape):\n",
    "    \"\"\"Crop `img` from the bottom to the target aspect ratio, then resize it.\n",
    "\n",
    "    Args:\n",
    "        img: Image array indexed as (rows, columns[, channels]).\n",
    "        shape: Target shape; only shape[0] (height) and shape[1] (width) are used.\n",
    "\n",
    "    Returns:\n",
    "        The cropped image resized to shape[0] rows and shape[1] columns.\n",
    "    \"\"\"\n",
    "    h, w = img.shape[:2]\n",
    "    ratio = shape[1] / shape[0]\n",
    "\n",
    "    # Keep the full width and only the bottom rows that give the target aspect\n",
    "    # ratio. Clamp at 0: for an image shorter than the crop height a negative\n",
    "    # start index would silently select the wrong rows.\n",
    "    top = max(0, h - int(w / ratio))\n",
    "    tgt = img[top:h, 0:w]\n",
    "    return cv2.resize(tgt, (shape[1], shape[0]))\n",
    "\n",
    "def draw_spline(img, points, color):\n",
    "    \"\"\"Draw one lane onto `img` (in place) as a 10 px wide open polyline.\n",
    "\n",
    "    `points[row]` is used as the x coordinate for image row `row`;\n",
    "    negative entries are skipped.\n",
    "    \"\"\"\n",
    "    xs = np.array(points)\n",
    "    valid = xs >= 0\n",
    "    rows = np.arange(len(xs))[valid]\n",
    "    # One (x, y) vertex per valid row\n",
    "    vertices = np.column_stack((xs[valid], rows))\n",
    "\n",
    "    cv2.polylines(img, np.int32([vertices]), isClosed=False, color=color,\n",
    "                  thickness=10, lineType=cv2.LINE_8)\n",
    "\n",
    "def generate_mask(file):\n",
    "    \"\"\"Render the two-channel lane mask for one label file.\n",
    "\n",
    "    `file` is a string tensor holding the path of the label JSON. Entries 1\n",
    "    and 2 of get_horizontal_values_for_four_lanes() are drawn onto a\n",
    "    LLAMAS-sized canvas with colors (0,255,0) and (255,0,0). The canvas is\n",
    "    then resized to MASK_SHAPE and its first two channels are returned.\n",
    "    \"\"\"\n",
    "    path = file.numpy().decode(\"utf-8\")\n",
    "    lanes = get_horizontal_values_for_four_lanes(path)\n",
    "\n",
    "    canvas = np.zeros(LLAMAS_SHAPE, dtype='uint8')\n",
    "    for lane, color in ((lanes[1], (0, 255, 0)), (lanes[2], (255, 0, 0))):\n",
    "        draw_spline(canvas, lane, color)\n",
    "\n",
    "    return resize_img(canvas, MASK_SHAPE)[:, :, 0:2]\n",
    "\n",
    "def read_image(file):\n",
    "    \"\"\"Load the color image that belongs to a label file.\n",
    "\n",
    "    Args:\n",
    "        file: String tensor holding the path of the label JSON.\n",
    "\n",
    "    Returns:\n",
    "        The RGB image, cropped and resized to IMAGE_SHAPE.\n",
    "\n",
    "    Raises:\n",
    "        FileNotFoundError: If the image is missing or cannot be decoded.\n",
    "    \"\"\"\n",
    "    file = file.numpy().decode(\"utf-8\")\n",
    "    with open(file, 'r') as fp:\n",
    "        meta = json.load(fp)\n",
    "\n",
    "    # Images live in a directory tree parallel to the labels\n",
    "    image_path = os.path.dirname(file).replace('/labels/', '/color_images/')\n",
    "    fname = os.path.join(image_path, meta['image_name'] + '_color_rect.png')\n",
    "\n",
    "    img = cv2.imread(fname)\n",
    "    if img is None:\n",
    "        # cv2.imread reports a missing or unreadable file by returning None\n",
    "        raise FileNotFoundError('Could not read image: %s' % fname)\n",
    "    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n",
    "    return resize_img(img, IMAGE_SHAPE)\n",
    "\n",
    "def process_json(json):\n",
    "    \"\"\"Map a label-file path tensor to an (image, mask) tensor pair.\n",
    "\n",
    "    Note: the parameter name shadows the `json` module inside this function.\n",
    "    \"\"\"\n",
    "    # The loaders run as Python code; the reshape pins the static shapes\n",
    "    # on the resulting tensors.\n",
    "    img = tf.reshape(tf.py_function(read_image, [json], Tout=DTYPE), shape=IMAGE_SHAPE)\n",
    "    mask = tf.reshape(tf.py_function(generate_mask, [json], Tout=DTYPE), shape=MASK_SHAPE)\n",
    "    return img, mask\n",
    "\n",
    "SHUFFLE_SIZE = 100000\n",
    "\n",
    "\n",
    "def _llamas_split(split):\n",
    "    \"\"\"Build the (image, mask) dataset for one split directory under BASE_PATH.\n",
    "\n",
    "    Returns a tuple (dataset, number of label files taken).\n",
    "    \"\"\"\n",
    "    files = tf.data.Dataset.list_files(os.path.join(BASE_PATH, split, '*/*.json'))\n",
    "    files = files.shuffle(SHUFFLE_SIZE).take(TAKE_LLAMAS_MAX)\n",
    "    # Counting iterates over the file names only; nothing is loaded yet\n",
    "    count = len(list(files))\n",
    "    return files.map(process_json, num_parallel_calls=PARALLEL_JOBS), count\n",
    "\n",
    "\n",
    "llamas_train_ds, LLAMAS_TRAIN_LEN = _llamas_split('train')\n",
    "print('LLAMAS training images: %d' % LLAMAS_TRAIN_LEN)\n",
    "\n",
    "llamas_valid_ds, LLAMAS_VALID_LEN = _llamas_split('valid')\n",
    "print('LLAMAS validation images: %d' % LLAMAS_VALID_LEN)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for img, mask in llamas_valid_ds.take(3):\n",
    "    # Pad the mask with one all-zero channel so it can be shown as a\n",
    "    # 3-channel image. The padding is derived from the mask itself, so it\n",
    "    # keeps matching when the mask height/width configuration changes.\n",
    "    padding = tf.zeros_like(mask[..., :1])\n",
    "    m = tf.concat([mask, padding], -1)\n",
    "    display([img, m])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "\n",
    "json_glob = glob.glob(os.path.join(BASE_PATH, 'train', '*/*.json'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "json_file = json_glob[0]\n",
    "lines = get_horizontal_values_for_four_lanes(json_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "CLS_SHAPE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# prefetch() follows batch(), so PREFETCH_SIZE counts batches, not images\n",
    "train_ds = llamas_train_ds.batch(BATCH_SIZE)\n",
    "train_ds = train_ds.prefetch(PREFETCH_SIZE)\n",
    "\n",
    "# Use the configured VALID_RATIO instead of a duplicated literal\n",
    "valid_ds = llamas_valid_ds.take(int(TAKE_LLAMAS_MAX * VALID_RATIO)).batch(BATCH_SIZE)\n",
    "valid_ds = valid_ds.prefetch(PREFETCH_SIZE)\n",
    "\n",
    "ufm.fit(train_ds, validation_data=valid_ds, epochs=EPOCHS, verbose=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for img, mask in llamas_valid_ds.take(5):\n",
    "    # Time a single-image prediction\n",
    "    s = time.time()\n",
    "    pred = ufm.predict(tf.expand_dims(img, axis=0))\n",
    "    e = time.time()\n",
    "    print(e-s)\n",
    "    pred = pred[0]\n",
    "\n",
    "    # Pad prediction and label with one all-zero channel so both can be\n",
    "    # shown as 3-channel images. The padding is derived from the mask\n",
    "    # instead of a hard-coded (20,100,1) shape.\n",
    "    padding = tf.zeros_like(mask[..., :1])\n",
    "    m = tf.concat([pred, padding], -1)\n",
    "    m2 = tf.concat([mask, padding], -1)\n",
    "\n",
    "    display([img, m2, m])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ufm.save('ultrafast.tf', save_format='tf')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Converting to TF Lite for use with Google EdgeTPU\n",
    "\n",
    "This code converts and quantizes the network for use with a Coral stick. The resulting `ultrafast_quant.tflite` still needs to be compiled with the Edge TPU compiler."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def representative_data_gen():\n",
    "    \"\"\"Yield the validation images one by one, each as a batch of size 1.\"\"\"\n",
    "    for img, _ in llamas_valid_ds:\n",
    "        yield [tf.expand_dims(img, axis=0)]\n",
    "\n",
    "converter = tf.lite.TFLiteConverter.from_keras_model(ufm)\n",
    "\n",
    "# Enable quantization; the representative dataset is what the converter\n",
    "# uses to quantize the activations\n",
    "converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
    "converter.representative_dataset = representative_data_gen\n",
    "\n",
    "# Restrict the model to int8 ops: the converter throws an error if any op\n",
    "# can't be quantized\n",
    "converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]\n",
    "converter.target_spec.supported_types = [tf.int8]\n",
    "\n",
    "# Input and output tensors are uint8 (added in r2.3)\n",
    "converter.inference_input_type = tf.uint8\n",
    "converter.inference_output_type = tf.uint8\n",
    "\n",
    "tflite_model = converter.convert()\n",
    "\n",
    "with open('ultrafast_quant.tflite', 'wb') as f:\n",
    "    f.write(tflite_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Test the TF Lite model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "interpreter = tf.lite.Interpreter('ultrafast_quant.tflite')\n",
    "interpreter.allocate_tensors()\n",
    "\n",
    "def set_input(interpreter, tensor):\n",
    "    \"\"\"Quantize one image and copy it into the interpreter's first input.\n",
    "\n",
    "    Args:\n",
    "        interpreter: A tf.lite.Interpreter with allocated tensors.\n",
    "        tensor: A single float image without batch dimension.\n",
    "    \"\"\"\n",
    "    input_details = interpreter.get_input_details()[0]\n",
    "    print(input_details)\n",
    "\n",
    "    scale, zero_point = input_details['quantization']\n",
    "    dtype = input_details['dtype']\n",
    "\n",
    "    data = np.asarray(tensor, dtype=np.float32)\n",
    "    if scale:\n",
    "        # real value -> quantized value; skipped for a zero scale to avoid\n",
    "        # dividing by zero\n",
    "        data = np.round(data / scale + zero_point)\n",
    "    if np.issubdtype(dtype, np.integer):\n",
    "        # Clip so out-of-range values saturate instead of wrapping around\n",
    "        limits = np.iinfo(dtype)\n",
    "        data = np.clip(data, limits.min, limits.max)\n",
    "\n",
    "    # set_tensor copies the data, so no reference to the interpreter's\n",
    "    # internal buffer is kept around\n",
    "    interpreter.set_tensor(input_details['index'], np.expand_dims(data.astype(dtype), axis=0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for img, mask in llamas_valid_ds.take(1):\n",
    "    set_input(interpreter, img)\n",
    "    interpreter.invoke()\n",
    "\n",
    "    output_details = interpreter.get_output_details()[0]\n",
    "    print(output_details)\n",
    "    output = interpreter.get_tensor(output_details['index'])\n",
    "\n",
    "    # Outputs from the TFLite model are uint8, so we dequantize the results.\n",
    "    # Convert to float first so the subtraction cannot wrap around.\n",
    "    scale, zero_point = output_details['quantization']\n",
    "    output = output.astype(np.float32)\n",
    "    if scale:\n",
    "        # Skipped for a zero scale, which would zero out the whole output\n",
    "        output = (scale * (output - zero_point)).astype(np.float32)\n",
    "\n",
    "    # Pad prediction and label with one all-zero channel for display\n",
    "    padding = tf.zeros_like(mask[..., :1])\n",
    "    m = tf.concat([output[0], padding], -1)\n",
    "    m2 = tf.concat([mask, padding], -1)\n",
    "\n",
    "    display([img, m2, m])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Test loading a saved TF model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ufml = tf.keras.models.load_model('ultrafast.tf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ufml.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
